from keras.layers import Input, Concatenate, Dense, Conv2D, Activation, Dropout, Flatten, BatchNormalization, AveragePooling2D
from keras.models import Model
from keras.optimizers import SGD
from keras.callbacks import ModelCheckpoint, CSVLogger

def conv2d_bn(input, nb_filters, kernel_size, padding='valid',
              data_format='channels_first', use_bias=True, dropout_rate=0):
    """Apply Conv2D -> BatchNormalization -> ReLU, then optional Dropout.

    Args:
        input: Keras tensor to convolve.
        nb_filters: Number of convolution filters (output channels).
        kernel_size: Convolution kernel size, passed straight to ``Conv2D``.
        padding: ``Conv2D`` padding mode (``'valid'`` or ``'same'``).
        data_format: ``'channels_first'`` or ``'channels_last'``; also
            decides which axis BatchNormalization normalizes over.
        use_bias: Whether the convolution has a bias term.
        dropout_rate: If strictly greater than 0, a ``Dropout`` layer with
            this rate is appended; otherwise no dropout is added.

    Returns:
        The output Keras tensor of the last layer applied.
    """
    # BatchNormalization must normalize over the channel axis, whose
    # position depends on the data format.
    channel_axis = -1
    if data_format == 'channels_first':
        channel_axis = 1

    conv = Conv2D(nb_filters, kernel_size, padding=padding,
                  data_format=data_format, use_bias=use_bias)
    out = conv(input)
    out = BatchNormalization(axis=channel_axis)(out)
    out = Activation("relu")(out)

    if dropout_rate > 0:
        return Dropout(dropout_rate)(out)
    return out

def base_model(nb_classes):
    """Build a plain stacked-convolution softmax classifier.

    Architecture: three ``conv2d_bn`` stages (100 filters, (5, 2) kernel,
    'valid' padding, dropout 0.5) -> Flatten -> Dense(300, relu) ->
    Dense(nb_classes, softmax).

    The input is declared with per-sample shape (59, 34, 4). Because
    ``conv2d_bn`` defaults to ``channels_first``, axis 1 (size 59) is treated
    as the channel axis and (34, 4) as the spatial extent. What these
    dimensions represent is not visible here — TODO confirm with data code.

    Args:
        nb_classes: Number of output classes (units of the softmax layer).

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    inputs = Input(shape=(59, 34, 4))

    features = inputs
    for _ in range(3):
        features = conv2d_bn(features, 100, (5, 2), dropout_rate=0.5)

    features = Flatten()(features)
    features = Dense(300, activation="relu")(features)
    predictions = Dense(units=nb_classes, activation="softmax")(features)
    return Model(inputs, predictions)

def _inception_block(x, nb_filters=32, dropout_rate=0.3):
    """Run three parallel conv branches on ``x``, concatenate them, then drop out.

    Each branch is ``conv2d_bn`` with 'same' padding and no bias (the
    following BatchNormalization supplies the shift). Kernel sizes are
    (3, 1), (1, 2) and (5, 2), so branches see different receptive fields
    but keep the same spatial size and can be concatenated.

    Args:
        x: Input Keras tensor in ``channels_first`` layout.
        nb_filters: Filters per branch; the block outputs 3 * nb_filters
            channels.
        dropout_rate: Rate of the Dropout applied after concatenation.

    Returns:
        The output Keras tensor.
    """
    branches = [
        conv2d_bn(x, nb_filters, kernel_size, padding='same', use_bias=False)
        for kernel_size in ((3, 1), (1, 2), (5, 2))
    ]
    # axis=1 is the channel axis because conv2d_bn defaults to channels_first.
    x = Concatenate(axis=1)(branches)
    return Dropout(dropout_rate)(x)


def inception_model(nb_classes):
    """Build an inception-style softmax classifier.

    Architecture: 1x1 ``conv2d_bn`` stem (64 filters, dropout 0.3) -> three
    stacked inception blocks (three 32-filter branches each, dropout 0.3) ->
    AveragePooling2D((1, 2)) -> Dropout(0.2) -> Flatten ->
    Dense(nb_classes, softmax).

    The input is declared with per-sample shape (59, 34, 4) and everything
    runs in ``channels_first`` layout, so axis 1 (size 59) is the channel
    axis. The meaning of these dimensions is not visible here — TODO confirm.

    Args:
        nb_classes: Number of output classes (units of the softmax layer).

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    init = Input(shape=(59, 34, 4))
    x = conv2d_bn(init, 64, (1, 1), dropout_rate=0.3)

    for _ in range(3):
        x = _inception_block(x)

    # Pool only along the last spatial axis; pooling must use the same
    # channels_first layout as the rest of the network.
    x = AveragePooling2D((1, 2), data_format='channels_first')(x)
    x = Dropout(0.2)(x)
    x = Flatten()(x)
    output = Dense(units=nb_classes, activation='softmax')(x)
    return Model(init, output)